/*
 Copyright 2013--Present JMM_PROGNAME
 
 This file is distributed under the terms of the JMM_PROGNAME License.
 
 You should have received a copy of the JMM_PROGNAME License.
 If not, see <JMM_PROGNAME WEBSITE>.
*/
// CREATED    : 9/15/2015
// LAST UPDATE: 9/30/2015

#include "statsxx/machine_learning/neural_network/deep_belief_network/DBN.hpp"

// STL
#include <cstddef>  // std::size_t
#include <fstream>  // std::ofstream()
//#include <iostream> // std::cout, std::endl
#include <utility>  // std::pair<>, std::make_pair()
#include <vector>   // std::vector<>

// jScience
#include "jScience/linalg/Matrix.hpp" // Matrix<>

/*
// jNeuralNet
#include "datasets.hpp" // DataSet, partition_data_set()
*/
 
// stats++
#include "statsxx/machine_learning/restricted_Boltzmann_machine/RBM.hpp"  // RBM


// get the weight matrix and bias vector (for the entire DBN)
// << jmm: this may be a temporary routine; see TODO in the CHANGELOG >>
// note: the FULL weight matrix and bias vector (including output nodes) is returned
// note: the full weight matrix is returned, but only populated in lower-triangular form (it is obviously symmetric)
// note: weights connecting to the output nodes are (obviously) zero, since the output nodes do not know about them
// note: it is assumed that it is the hidden biases that are passed to the MLP
inline std::pair<
                 Matrix<double>,
                 Vector<double>
                 > neural_network::DBN::get_RBM_weights()
{
    // total number of neurons in the whole network (all layers, including output)
    const int nn = sum(this->architecture);
    
    // construct directly instead of default-construct-then-assign
    Matrix<double> W(nn, nn, 0.0);
    Vector<double> bias(nn, 0.0);
    
    // global index of the first visible neuron of the current RBM
    std::size_t cnt = 0;
    
    // note: recall that the weight matrix for each RBM is stored as [nhid x nvis]
    for(std::size_t i = 0; i < this->RBM.size(); ++i)
    {
        // hoist the loop-invariant layer sizes out of the inner loops
        const std::size_t nvis = this->RBM[i].W.size(1);
        const std::size_t nhid = this->RBM[i].W.size(0);
        
        // visible neurons are globally numbered [cnt, cnt + nvis);
        // hidden neurons are globally numbered [cnt + nvis, cnt + nvis + nhid)
        for(std::size_t j = 0; j < nvis; ++j)
        {
            for(std::size_t k = 0; k < nhid; ++k)
            {
                // note: recall that we only populate the matrix in lower-triangular form
                W((cnt + nvis + k), (cnt + j)) = this->RBM[i].W(k,j);
            }
        }
        
        // hidden-unit biases (it is the hidden biases that are passed to the MLP)
        for(std::size_t k = 0; k < nhid; ++k)
        {
            bias((cnt + nvis + k)) = this->RBM[i].c(k);
        }
        
        cnt += nvis;
    }
    
    // move the (potentially large) matrix and vector into the pair instead of copying
    return std::make_pair(std::move(W), std::move(bias));
}
